{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAPoU8Sm5E6e"
      },
      "source": [
        "# Hugging Face DLCs: Fine-tuning Gemma with Transformer Reinforcement Learning (TRL) on Vertex AI\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/open-models/fine-tuning/vertex_ai_trl_fine_tuning_gemma.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fopen-models%2Ffine-tuning%2Fvertex_ai_trl_fine_tuning_gemma.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/open-models/fine-tuning/vertex_ai_trl_fine_tuning_gemma.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/fine-tuning/vertex_ai_trl_fine_tuning_gemma.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/fine-tuning/vertex_ai_trl_fine_tuning_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/fine-tuning/vertex_ai_trl_fine_tuning_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/fine-tuning/vertex_ai_trl_fine_tuning_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/fine-tuning/vertex_ai_trl_fine_tuning_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/fine-tuning/vertex_ai_trl_fine_tuning_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>            "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "83b98b0ba19c"
      },
      "source": [
        "![Hugging Face x Google Cloud](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/google-cloud/thumbnail.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "| | |\n",
        "|-|-|\n",
        "| Author(s) | [Alvaro Bartolome](https://github.com/alvarobartt), [Philipp Schmid](https://github.com/philschmid), [Simon Pagezy](https://github.com/pagezyhf), and [Jeff Boudier](https://github.com/jeffboudier) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "## Overview"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3608b194054c"
      },
      "source": [
        "> [**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models, developed by Google DeepMind and other teams across Google.\n",
        "\n",
        "> [**Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) is a framework developed by Hugging Face to fine-tune and align both transformer language and diffusion models using methods such as Supervised Fine-Tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), Direct Preference Optimization (DPO), and others.\n",
        "\n",
        "> [**Hugging Face DLCs**](https://github.com/huggingface/Google-Cloud-Containers) are pre-built and optimized Deep Learning Containers (DLCs) maintained by Hugging Face and Google Cloud teams to simplify environment configuration for your ML workloads.\n",
        "\n",
        "> [**Google Vertex AI**](https://cloud.google.com/vertex-ai) is a Machine Learning (ML) platform that lets you train and deploy ML models and AI applications, and customize large language models (LLMs) for use in your AI-powered applications.\n",
        "\n",
        "This notebook showcases how to fine-tune Google Gemma with Supervised Fine-Tuning (SFT) and Low-Rank Adaptation (LoRA) in a single GPU via a custom job on Vertex AI using the Hugging Face PyTorch Deep Learning Container (DLC) for Training.\n",
        "\n",
        "By the end of this notebook, you will learn:\n",
        "\n",
        "- About Hugging Face's TRL and LLM fine-tuning\n",
        "- How to create and run a custom container job on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61RBz8LLbxCR"
      },
      "source": [
        "## Get started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "No17Cw5hgx12"
      },
      "source": [
        "### Install Vertex AI SDK and other required packages\n",
        "\n",
        "To run this example, you will only need the [`google-cloud-aiplatform`](https://github.com/googleapis/python-aiplatform) Python SDK and the [`huggingface_hub`](https://github.com/huggingface/huggingface_hub) Python package."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFy3H3aPgx12"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet google-cloud-aiplatform huggingface_hub"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R5Xep4W9lq-Z"
      },
      "source": [
        "### Restart runtime (Colab only)\n",
        "\n",
        "To use the newly installed packages in this Jupyter environment, if you are on Colab you must restart the runtime. You can do this by running the cell below, which restarts the current kernel. The restart might take a minute or longer. After it's restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XRvKdaPDTznN"
      },
      "outputs": [],
      "source": [
        "# Automatically restart kernel after installs so that your environment can access the new packages\n",
        "# import IPython\n",
        "\n",
        "# app = IPython.Application.instance()\n",
        "# app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SbmM4z7FOBpM"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dmWOrTJ3gx13"
      },
      "source": [
        "### Authenticate your Google Cloud account\n",
        "\n",
        "Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9cc332aee5fd"
      },
      "source": [
        "**1. Vertex AI Workbench**\n",
        "\n",
        "* Do nothing as you are already authenticated."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "54c81e1a63e6"
      },
      "source": [
        "**2. Local JupyterLab instance, uncomment and run:**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c93d96bc9c43"
      },
      "outputs": [],
      "source": [
        "# !gcloud auth login"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d3e571ce6c56"
      },
      "source": [
        "**3. Colab, uncomment and run:**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "984a0526fb68"
      },
      "outputs": [],
      "source": [
        "# from google.colab import auth\n",
        "# auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f9af3e57f89a"
      },
      "source": [
        "### Authenticate your Hugging Face account\n",
        "\n",
        "As [`google/gemma-2b`](https://huggingface.co/google/gemma-2b) is a gated model, you need to have a Hugging Face Hub account, and accept the Google's usage license for Gemma. Once that's done, you need to generate a new user access token with read-only access so that the weights can be downloaded from the Hub in the Hugging Face DLC for TGI.\n",
        "\n",
        "> Note that the user access token can only be generated via [the Hugging Face Hub UI](https://huggingface.co/settings/tokens/new), where you can either select read-only access to your account, or follow the recommendations and generate a fine-grained token with read-only access to [`google/gemma-2b`](https://huggingface.co/google/gemma-2b)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4c31c7272804"
      },
      "source": [
        "Then you can install the `huggingface_hub` that comes with a CLI that will be used for the authentication with the token generated in advance. So that then the token can be safely retrieved via `huggingface_hub.get_token`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8d836e0210fe"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import interpreter_login\n",
        "\n",
        "interpreter_login()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c71a4314c250"
      },
      "source": [
        "Read more about [Hugging Face Security](https://huggingface.co/docs/hub/en/security), specifically about [Hugging Face User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF4l8DTdWgPY"
      },
      "source": [
        "### Set Google Cloud environment variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c14bf7a1ce9e"
      },
      "source": [
        "You will need to set the following environment variables that are required to run this example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nqwi-5ufWp_B"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type:\"string\", isTemplate: true}\n",
        "if PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "os.environ[\"PROJECT_ID\"] = PROJECT_ID\n",
        "\n",
        "os.environ[\"LOCATION\"] = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d03b04d07dc0"
      },
      "source": [
        "### (Optional) Create a bucket in Google Cloud Storage (GCS)\n",
        "\n",
        "As the custom job on Vertex AI will need to dump the fine-tuned artifacts somewhere on Google Cloud, you will need to have a bucket available on Google Cloud Storage (GCS), so that you can specify it as the bucket to be used within the custom job, so that anything written within the container in that path is automatically uploaded to GCS."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0ca9b8a1bf63"
      },
      "outputs": [],
      "source": [
        "BUCKET_URI = \"[your-bucket-uri]\"  # @param {type:\"string\", isTemplate: true}\n",
        "if BUCKET_URI == \"[your-bucket-uri]\":\n",
        "    raise ValueError(\n",
        "        \"A valid BUCKET_URI (e.g. `gs://path/to/bucket`) needs to be specified\"\n",
        "    )\n",
        "os.environ[\"BUCKET_URI\"] = BUCKET_URI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c1a68956ee21"
      },
      "source": [
        "> Uncomment the `gcloud storage buckets create` command below if you need to create a bucket on GCS."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ff73fc11e33"
      },
      "outputs": [],
      "source": [
        "# !gcloud storage buckets create $BUCKET_URI --project $PROJECT_ID --location=$LOCATION --default-storage-class=STANDARD --uniform-bucket-level-access"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4d527513fda4"
      },
      "source": [
        "### Initialize Vertex AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com), if not enabled already.\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "865070bd8875"
      },
      "outputs": [],
      "source": [
        "aiplatform.init(\n",
        "    project=os.getenv(\"PROJECT_ID\"),\n",
        "    location=os.getenv(\"LOCATION\"),\n",
        "    staging_bucket=os.getenv(\"BUCKET_URI\"),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ee37e1544281"
      },
      "source": [
        "## Requirements"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "877cd3fb2dce"
      },
      "source": [
        "You will need to have the following IAM roles set:\n",
        "\n",
        "- Artifact Registry Reader (roles/artifactregistry.reader)\n",
        "- Vertex AI User (roles/aiplatform.user)\n",
        "- Storage Object Creator (roles/storage.objectCreator)\n",
        "\n",
        "For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).\n",
        "\n",
        "---\n",
        "\n",
        "You will also need to enable the following APIs (if not enabled already):\n",
        "\n",
        "- Vertex AI API (aiplatform.googleapis.com)\n",
        "- Cloud Storage API (storage-api.googleapis.com)\n",
        "- Artifact Registry API (artifactregistry.googleapis.com)\n",
        "\n",
        "For more information about API enablement, see [Enabling APIs](https://cloud.google.com/apis/docs/getting-started#enabling_apis).\n",
        "\n",
        "---\n",
        "\n",
        "You will also need to have a Google Cloud Storage (GCS) bucket where the custom job can read and write artifacts.\n",
        "\n",
        "---\n",
        "\n",
        "To access Gemma on Hugging Face, you're required to review and agree to Google's usage license on the Hugging Face Hub for any of the models from the [Gemma release collection](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b), and the access request will be processed immediately."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EdvJRUWRNGHE"
      },
      "source": [
        "## Create a `CustomContainerTrainingJob` on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2d1505636843"
      },
      "source": [
        "You need to define a `CustomContainerTrainingJob` that runs on top of the Hugging Face PyTorch DLC for Training.\n",
        "\n",
        "The Hugging Face PyTorch DCL for Training comes with most of the Hugging Face Python libraries installed, so as to provide a seamless environment to use the Hugging Face stack on Google Cloud. In this example, as already mentioned, [`trl`](https://github.com/huggingface/trl) will be used to run the Supervised Fine-Tuning (SFT) on top of [`transformers`](https://github.com/huggingface/transformers) for the modelling, [`peft`](https://github.com/huggingface/peft) for the LoRA support and [`bitsandbytes`](https://github.com/huggingface/bitsandbytes) for the quantization support.\n",
        "\n",
        "Since the Hugging Face PyTorch DLC for Training does not have any pre-defined `CMD` or `ENTRYPOINT`, you will need to set it to `trl sft`; which is the TRL command to run the SFT fine-tuning via the recently released [TRL CLI](https://huggingface.co/docs/trl/en/clis)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f223713c7111"
      },
      "outputs": [],
      "source": [
        "job = aiplatform.CustomContainerTrainingJob(\n",
        "    display_name=\"gemma-2b-sft-lora\",\n",
        "    container_uri=\"us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-training-cu121.2-3.transformers.4-42.ubuntu2204.py310\",\n",
        "    command=[\"trl\", \"sft\"],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "70ccc1885fbb"
      },
      "source": [
        "## Prepare `CustomContainerTrainingJob` on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61f2d20694d6"
      },
      "source": [
        "Before submitting the `CustomContainerTrainingJob` you need to define the following:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ae5cb15e64c2"
      },
      "source": [
        "### Model to fine-tune\n",
        "\n",
        "In this case you will be fine-tuning a base model, [`google/gemma-2b`](https://huggingface.co/google/gemma-2b), meaning that the model by itself will just generate text and has not been fined-tuned in advance i.e. no prompt templates, etc."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c914172f34d9"
      },
      "source": [
        "### Model requirements\n",
        "\n",
        "Fine-tuning a model can be expensive, you can estimate that to full fine-tune an LLM in half precision you would need around four times the disk size of the LLM in GPU VRAM.\n",
        "\n",
        "For a 2B LLM that takes around 5GiB on disk, it could translate into 20GiB for the half precision fine-tuning; but that could be reduced further by using a quantized optimizer, using LoRA or QLoRA, reducing the batch size, etc. In this case, you will be using an L4 NVIDIA GPU that comes with 24GiB of VRAM so both SFT or SFT + LoRA will work out of the box as those fit within the 24GiB VRAM limit.\n",
        "\n",
        "Read more about it in [Eleuther AI - Transformer Math 101](https://blog.eleuther.ai/transformer-math/)."
      ]
    },
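    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the rule of thumb above (VRAM for half-precision full fine-tuning is roughly four times the model's size on disk), the arithmetic can be written as a small helper. The function names and the `multiplier` parameter are hypothetical and only illustrate the estimate; actual memory usage also depends on the batch size, the sequence length, and the optimizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def estimate_full_finetune_vram_gib(\n",
        "    disk_size_gib: float, multiplier: float = 4.0\n",
        ") -> float:\n",
        "    # Rule of thumb from the text: VRAM is about `multiplier` times the size on disk\n",
        "    return disk_size_gib * multiplier\n",
        "\n",
        "\n",
        "def fits_on_gpu(required_gib: float, available_gib: float) -> bool:\n",
        "    # True when the estimated requirement does not exceed the available VRAM\n",
        "    return required_gib <= available_gib\n",
        "\n",
        "\n",
        "# Figures taken from the text: ~5 GiB on disk, 24 GiB of VRAM on an NVIDIA L4\n",
        "required = estimate_full_finetune_vram_gib(5.0)\n",
        "print(f\"Estimated VRAM: {required:.0f} GiB\")\n",
        "print(f\"Fits on a 24 GiB GPU: {fits_on_gpu(required, 24.0)}\")"
      ]
    },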
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8dc06eda3134"
      },
      "source": [
        "### Dataset to fine-tune on\n",
        "\n",
        "SFT expects either a text-only, a conversational, or a prompt-completion dataset; in this case, a text-only dataset [`timdettmers/openassistant-guanaco`](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) will be used.\n",
        "\n",
        "This dataset is a subset of the [Open Assistant dataset](https://huggingface.co/datasets/OpenAssistant/oasst1); and this subset only contains the highest-rated paths in the conversation tree, with a total of 9,846 samples."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0ad5dcdcde93"
      },
      "source": [
        "## Submit `CustomContainerTrainingJob` on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7bcc72c900e5"
      },
      "source": [
        "As the `CustomContainerTrainingJob` defines the command `trl sft` the arguments to be provided are listed either in the Python reference at [trl.SFTConfig](https://huggingface.co/docs/trl/en/sft_trainer#trl.SFTConfig) or via the `trl sft --help` command.\n",
        "\n",
        "Read more about the [TRL CLI](https://huggingface.co/docs/trl/en/clis)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "af41d90a8a96"
      },
      "outputs": [],
      "source": [
        "args = [\n",
        "    # MODEL\n",
        "    \"--model_name_or_path=google/gemma-2b\",\n",
        "    \"--torch_dtype=bfloat16\",\n",
        "    \"--attn_implementation=flash_attention_2\",\n",
        "    # DATASET\n",
        "    \"--dataset_name=timdettmers/openassistant-guanaco\",\n",
        "    \"--dataset_text_field=text\",\n",
        "    # PEFT\n",
        "    \"--use_peft\",\n",
        "    \"--lora_r=16\",\n",
        "    \"--lora_alpha=32\",\n",
        "    \"--lora_dropout=0.1\",\n",
        "    \"--lora_target_modules=all-linear\",\n",
        "    # TRAINER\n",
        "    \"--bf16\",\n",
        "    \"--max_seq_length=1024\",\n",
        "    \"--per_device_train_batch_size=8\",\n",
        "    \"--gradient_accumulation_steps=4\",\n",
        "    \"--gradient_checkpointing\",\n",
        "    \"--learning_rate=0.0002\",\n",
        "    \"--lr_scheduler_type=cosine\",\n",
        "    \"--optim=adamw_bnb_8bit\",\n",
        "    \"--num_train_epochs=1\",\n",
        "    \"--logging_steps=10\",\n",
        "    \"--do_eval\",\n",
        "    \"--eval_steps=100\",\n",
        "    \"--report_to=none\",\n",
        "    f\"--output_dir={os.getenv('BUCKET_URI').replace('gs://', '/gcs/')}/gemma-2b-sft-lora\",\n",
        "    \"--overwrite_output_dir\",\n",
        "    \"--seed=42\",\n",
        "    \"--log_level=debug\",\n",
        "]"
      ]
    },
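    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To sanity-check what the container will run, the `command` and the `args` can be joined into a single shell-style string. This is only an illustrative sketch: `preview_command` and `sample_args` are hypothetical names, and the sample list is a shortened stand-in for the full `args` list defined above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import shlex\n",
        "\n",
        "\n",
        "def preview_command(command: list, args: list) -> str:\n",
        "    # Join the command and its arguments, quoting any value that needs it\n",
        "    return shlex.join([*command, *args])\n",
        "\n",
        "\n",
        "# Shortened, illustrative subset of the arguments\n",
        "sample_args = [\n",
        "    \"--model_name_or_path=google/gemma-2b\",\n",
        "    \"--use_peft\",\n",
        "    \"--lora_r=16\",\n",
        "]\n",
        "print(preview_command([\"trl\", \"sft\"], sample_args))"
      ]
    },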
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77cdb9afabbf"
      },
      "source": [
        "> It's important to note that since GCS FUSE is used to mount the bucket as a directory within the running container job, the mounted path follows the formatting `/gcs/<BUCKET_NAME>`; meaning that anything the `SFTTrainer` writes there will be automatically uploaded to the GCS Bucket.\n",
        ">\n",
        "> More information in the [Vertex AI Documentation - Prepare training code](https://cloud.google.com/vertex-ai/docs/training/code-requirements)."
      ]
    },
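    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `gs://` to `/gcs/` mapping used for `--output_dir` can also be expressed as a small helper that validates its input first. This is a sketch under stated assumptions: `gcs_uri_to_fuse_path` is a hypothetical name, and `gs://my-bucket/...` is a placeholder URI rather than a real bucket."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def gcs_uri_to_fuse_path(uri: str) -> str:\n",
        "    # Map gs://<bucket>/<path> to the mounted path /gcs/<bucket>/<path>\n",
        "    prefix = \"gs://\"\n",
        "    if not uri.startswith(prefix) or len(uri) == len(prefix):\n",
        "        raise ValueError(f\"Expected a URI such as gs://my-bucket/path, got {uri!r}\")\n",
        "    return \"/gcs/\" + uri[len(prefix) :].rstrip(\"/\")\n",
        "\n",
        "\n",
        "# Placeholder bucket name, used for illustration only\n",
        "print(gcs_uri_to_fuse_path(\"gs://my-bucket/gemma-2b-sft-lora\"))"
      ]
    },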
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3bc2dce404dc"
      },
      "source": [
        "Then you need to call the `submit` method on the `aiplatform.CustomContainerTrainingJob`, which is a non-blocking method that will schedule the job without blocking the execution.\n",
        "\n",
        "The arguments provided to the `submit` method are listed below:\n",
        "\n",
        "* **`args`** defines the list of arguments to be provided to the `trl sft` command, provided as `trl sft --arg_1=value ...`.\n",
        "\n",
        "* **`replica_count`** defines the number of replicas to run the job in, for training normally this value will be set to one.\n",
        "\n",
        "* **`machine_type`**, **`accelerator_type`** and **`accelerator_count`** define the machine i.e. Compute Engine instance, the accelerator (if any), and the number of accelerators (ranging from 1 to 8); respectively. The `machine_type` and the `accelerator_type` are tied together, so you will need to select an instance that supports the accelerator that you are using and vice-versa. More information about the different instances at [Compute Engine Documentation - GPU machine types](https://cloud.google.com/compute/docs/gpus), and about the `accelerator_type` naming at [Vertex AI Documentation - MachineSpec](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec).\n",
        "\n",
        "* **`base_output_dir`** defines the base directory that will be mounted within the running container from the GCS Bucket, conditioned by the `staging_bucket` argument provided to the `aiplatform.init` initially.\n",
        "\n",
        "* (optional) **`environment_variables`** defines the environment variables to define within the running container. As you are fine-tuning a gated model i.e. [`google/gemma-2b`](https://huggingface.co/google/gemma-2b), you need to set the `HF_TOKEN` environment variable. Additionally, some other environment variables are defined to set the cache path (`HF_HOME`) and to ensure that the logging messages are streamed to Google Cloud Logs Explorer properly (`TRL_USE_RICH`, `ACCELERATE_LOG_LEVEL`, `TRANSFORMERS_LOG_LEVEL`, and `TQDM_POSITION`).\n",
        "\n",
        "* (optional) **`timeout`** and **`create_request_timeout`** define the timeouts in seconds before interrupting the job execution or the job creation request (time to allocate required resources and start the execution), respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0e685d739cc6"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import get_token\n",
        "\n",
        "job.submit(\n",
        "    args=args,\n",
        "    replica_count=1,\n",
        "    machine_type=\"g2-standard-12\",\n",
        "    accelerator_type=\"NVIDIA_L4\",\n",
        "    accelerator_count=1,\n",
        "    base_output_dir=f\"{os.getenv('BUCKET_URI')}/gemma-2b-sft-lora\",\n",
        "    environment_variables={\n",
        "        \"HF_HOME\": \"/root/.cache/huggingface\",\n",
        "        \"HF_TOKEN\": get_token(),\n",
        "        \"TRL_USE_RICH\": \"0\",\n",
        "        \"ACCELERATE_LOG_LEVEL\": \"INFO\",\n",
        "        \"TRANSFORMERS_LOG_LEVEL\": \"INFO\",\n",
        "        \"TQDM_POSITION\": \"-1\",\n",
        "    },\n",
        "    timeout=60 * 60 * 3,  # 3 hours (10800s)\n",
        "    create_request_timeout=60 * 10,  # 10 minutes (600s)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "da780e6ea3ff"
      },
      "source": [
        "> The `CustomContainerTrainingJob` will run asynchronously, meaning that the execution won't be blocked and the job will automatically allocate and deallocate the required resources. So once triggered, you can go to Vertex AI in the Google Cloud Console to monitor the job closely."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1a39697ebbdf"
      },
      "source": [
        "## What's next?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1b87bbe001c8"
      },
      "source": [
        "Once the fine-tuning is done, you can already deploy it!\n",
        "\n",
        "To deploy it, you can either [merge the LoRA adapters into the fine-tuned Google Gemma model](https://huggingface.co/docs/peft/en/developer_guides/model_merging) in advance, or just [deploy the base model and the adapter separately](https://huggingface.co/blog/multi-lora-serving).\n",
        "\n",
        "As deployment and serving is out of the scope of this example, you can refer to the following examples in [`GoogleCloudPlatform/generative-ai`](https://github.com/GoogleCloudPlatform/generative-ai):\n",
        "\n",
        "- [Hugging Face DLCs: Serving Gemma with Text Generation Inference (TGI) on Vertex AI](./vertex_ai_text_generation_inference_gemma.ipynb)\n",
        "\n",
        "Or to the examples available within the [`huggingface/Google-Cloud-Containers`](https://github.com/huggingface/Google-Cloud-Containers) repository:\n",
        "\n",
        "- [Deploy Gemma 7B with TGI on Vertex AI](https://github.com/huggingface/Google-Cloud-Containers/blob/main/examples/vertex-ai/notebooks/deploy-gemma-on-vertex-ai/vertex-notebook.ipynb)\n",
        "- [Deploy Gemma 7B from GCS with TGI on Vertex AI](https://github.com/huggingface/Google-Cloud-Containers/blob/main/examples/vertex-ai/notebooks/deploy-gemma-from-gcs-on-vertex-ai/vertex-notebook.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "392ded651e9e"
      },
      "source": [
        "## References"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "137d82da56f2"
      },
      "source": [
        "- [GitHub Repository - Hugging Face DLCs for Google Cloud](https://github.com/huggingface/Google-Cloud-Containers): contains all the containers developed by the collaboration of both Hugging Face and Google Cloud teams; as well as a lot of examples on both training and inference, covering both CPU and GPU, as well support for most of the models within the Hugging Face Hub.\n",
        "- [Google Cloud Documentation - Hugging Face DLCs](https://cloud.google.com/deep-learning-containers/docs/choosing-container#hugging-face): contains a table with the latest released Hugging Face DLCs on Google Cloud.\n",
        "- [Google Artifact Registry - Hugging Face DLCs](https://console.cloud.google.com/artifacts/docker/deeplearning-platform-release/us/gcr.io): contains all the DLCs released by Google Cloud that can be used.\n",
        "- [Hugging Face Documentation - Google Cloud](https://huggingface.co/docs/google-cloud): contains the official Hugging Face documentation for the Google Cloud DLCs."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "vertex_ai_trl_fine_tuning_gemma.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
